Tseah/convert torchft replica group #60820

Draft
TimothySeah wants to merge 22 commits into ray-project:master from TimothySeah:tseah/convert-torchft-replica-group

Conversation

@TimothySeah
Contributor

Thank you for contributing to Ray! 🚀
Please review the Ray Contribution Guide before opening a pull request.

⚠️ Remove these instructions before submitting your PR.

💡 Tip: Mark as draft if you want early feedback, or ready for review when it's complete.

Description

Briefly describe what this PR accomplishes and why it's needed.

Related issues

Link related issues: "Fixes #1234", "Closes #1234", or "Related to #1234".

Additional information

Optional: Add implementation details, API changes, usage examples, screenshots, etc.

aslonnie and others added 22 commits December 16, 2025 14:30
Signed-off-by: Lonnie Liu <lonnie@anyscale.com>
cherry-pick of ray-project#59494

Signed-off-by: Lonnie Liu <lonnie@anyscale.com>
… you request 0 GPUs on CPU-only cluster (ray-project#59516)

Cherry-pick of ray-project#59514

Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
…ct#59519)

EWMA_ALPHA
Update EWMA_ALPHA from 0.2 -> 0.1. This biases the adjusting level toward limiting concurrency by making it more sensitive to downstream queueing.

K_DEV
Update K_DEV from 2.0 -> 1.0. This biases the stddev term toward limiting concurrency by making it more sensitive to downstream queueing.

cherry-pick of ray-project#59392
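
For context, here is a minimal sketch of how an EWMA level and a k * stddev band are typically combined into a backpressure check. The constant values mirror the commit above, but the class and method names are illustrative and are not taken from the Ray Data implementation.

    # Illustrative only: shows how EWMA_ALPHA- and K_DEV-style constants interact.
    EWMA_ALPHA = 0.1  # smaller alpha -> smoother, slower-moving level estimate
    K_DEV = 1.0       # smaller k -> tighter band around the level

    class QueueBackpressure:
        """Tracks a downstream queue metric and decides when to throttle."""

        def __init__(self) -> None:
            self.level = 0.0      # EWMA of the observed queue length
            self.variance = 0.0   # EWMA of the squared deviation

        def observe(self, queued: float) -> None:
            # Standard incremental EWMA update for the level and its variance.
            delta = queued - self.level
            self.level += EWMA_ALPHA * delta
            self.variance = (1 - EWMA_ALPHA) * (self.variance + EWMA_ALPHA * delta * delta)

        def should_throttle(self, queued: float) -> bool:
            # Throttle when the current queue exceeds level + K_DEV * stddev.
            threshold = self.level + K_DEV * self.variance ** 0.5
            return queued > threshold

With a smaller K_DEV the threshold hugs the smoothed level more tightly, and with a smaller EWMA_ALPHA the level reacts more slowly to spikes, so a burst of downstream queueing crosses the threshold sooner; both changes push toward limiting concurrency.
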
…oject#59606)

Created by release automation bot.

Update with commit 0de2118

Signed-off-by: Lonnie Liu <lonnie@anyscale.com>
Co-authored-by: Lonnie Liu <lonnie@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for torchft in Ray Train, enabling more granular fault tolerance through replica groups. The changes are extensive, touching configuration, backend setup, controller logic, and worker group management. Key additions include the TorchftConfig, a new _TorchftBackend that manages per-replica-group process groups, and the logic in WorkerGroup to replace failed replica groups. The PR also includes several improvements to Ray Data's backpressure mechanism and a bug fix in the autoscaling coordinator.

Overall, this is a significant feature addition. I've found one critical issue related to the autoscaling/resizing logic and one opportunity for refactoring to reduce code duplication.
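
As context for what "per-replica-group process groups" means here, the sketch below partitions world ranks into replica groups and builds one torch.distributed process group per partition. It illustrates the general pattern only and is not the code added by this PR; the helper name and the contiguous-partition scheme are assumptions.

    # Illustrative sketch of per-replica-group process groups, not the PR's implementation.
    import torch.distributed as dist

    def build_replica_group(world_size: int, num_replica_groups: int, rank: int):
        """Partition ranks into contiguous replica groups and return this rank's group.

        new_group() must be called by every rank for every group (in the same order),
        so all groups are created even though each rank only keeps its own.
        """
        group_size = world_size // num_replica_groups
        my_group = None
        for g in range(num_replica_groups):
            ranks = list(range(g * group_size, (g + 1) * group_size))
            pg = dist.new_group(ranks=ranks)
            if rank in ranks:
                my_group = pg
        return my_group

    # Usage, after dist.init_process_group(...):
    #   group = build_replica_group(dist.get_world_size(), 2, dist.get_rank())
    #   dist.all_reduce(tensor, group=group)  # collectives stay within the replica group

Replacing a failed replica group then only requires re-forming that group's process group, which is the kind of recovery the _TorchftBackend and WorkerGroup changes in this PR aim to support.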

Comment on lines +180 to +201
    async def _execute_resize_decision(
        self, decision: ResizeDecision
    ) -> TrainControllerLoopIterationResult:
        """Executes resize decisions."""

        for callback in self._controller_callbacks:
            callback.before_controller_execute_resize_decision(decision)

        if self._worker_group:
            self._shutdown_worker_group()
        optional_controller_error = None

        optional_controller_error = self._start_worker_group(
            num_workers=decision.num_workers,
            resources_per_worker=decision.resources_per_worker,
        )
        if self._worker_group:
            # Replace bad workers in the existing worker group
            # TODO: propagate poll_status rather than recalculating it
            try:
                self._replace_bad_workers(await self._poll_workers())
            except Exception as e:
                optional_controller_error = ControllerError(e)
        else:
            optional_controller_error = self._start_worker_group(
                num_workers=decision.num_workers,
                resources_per_worker=decision.resources_per_worker,
            )

critical

The logic in _execute_resize_decision has been changed from handling resizing to handling failure recovery. The original implementation, which performed a full restart of the worker group to apply a ResizeDecision, has been replaced with a call to _replace_bad_workers. This new logic only recovers failed workers and does not adjust the total number of workers, which effectively breaks autoscaling (both scaling up and down).

If the intention is to disable scaling when using torchft, this should be handled more explicitly, for instance by ensuring the ScalingPolicy does not generate a ResizeDecision. As it stands, this change is a regression in functionality.
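
One way to reconcile the two behaviors, sketched below: take the replace-bad-workers path only when the target size is unchanged, and keep the original full restart when the worker count actually changes. This is an illustration of the reviewer's suggestion, not the PR's actual fix; the get_workers() accessor used for the size check is an assumption.

    # Sketch only: names other than the assumed get_workers() come from the diff above.
    async def _execute_resize_decision(
        self, decision: ResizeDecision
    ) -> TrainControllerLoopIterationResult:
        """Executes resize decisions."""
        for callback in self._controller_callbacks:
            callback.before_controller_execute_resize_decision(decision)

        optional_controller_error = None
        size_unchanged = (
            self._worker_group is not None
            and len(self._worker_group.get_workers()) == decision.num_workers  # assumed accessor
        )
        if size_unchanged:
            # Same target size: recover failed workers in place (torchft path).
            try:
                self._replace_bad_workers(await self._poll_workers())
            except Exception as e:
                optional_controller_error = ControllerError(e)
        else:
            # Size changed: keep the original full restart so scale-up/down still works.
            if self._worker_group:
                self._shutdown_worker_group()
            optional_controller_error = self._start_worker_group(
                num_workers=decision.num_workers,
                resources_per_worker=decision.resources_per_worker,
            )
        # ... build and return the TrainControllerLoopIterationResult as before.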

Comment on lines +754 to +828
        # Re-initialize backend (per-group TCPStore + init_process_group)
        # via BackendSetupCallback
        from ray.train.v2._internal.callbacks.backend_setup import BackendSetupCallback

        for callback in self._callbacks:
            if isinstance(callback, BackendSetupCallback):
                # First update workers in state so the callback can access them
                new_workers_by_rank = {
                    w.distributed_context.world_rank: w for w in new_workers
                }
                updated_workers = [
                    new_workers_by_rank.get(w.distributed_context.world_rank, w)
                    for w in workers
                ]
                self._worker_group_state = WorkerGroupState(
                    start_time=self._worker_group_state.start_time,
                    placement_group=pg,
                    workers=updated_workers,
                    sync_actor=sync_actor,
                )
                callback.reinitialize_workers(self, target_group.world_ranks)
                break

        # Get train context args from callbacks
        train_context_args = {}
        for cb in self._callbacks:
            args = cb.before_init_train_context(new_workers)
            for arg, arg_values in args.items():
                assert len(arg_values) == len(new_workers), (
                    f"Callback {cb} returned {arg} with "
                    f"{len(arg_values)} values, expected {len(new_workers)}."
                )
                assert (
                    arg not in train_context_args
                ), f"Callback {cb} returned {arg} which is already set."
                train_context_args[arg] = arg_values

        # Initialize train context on new workers
        try:
            self._init_train_context_on_workers(
                new_workers, sync_actor, train_context_args
            )
        except RayActorError as actor_error:
            for worker in new_workers:
                ray.kill(worker.actor)
            error_msg = (
                "Replacement workers failed during train context initialization."
            )
            raise WorkerGroupStartupFailedError(error_msg) from actor_error

        # Launch training function on new workers
        ray_get_safe(
            [
                worker.actor.run_train_fn.remote(
                    self._worker_group_context.train_fn_ref
                )
                for worker in new_workers
            ]
        )

        # Update state if not already updated above (in case no BackendSetupCallback)
        if not any(isinstance(cb, BackendSetupCallback) for cb in self._callbacks):
            new_workers_by_rank = {
                w.distributed_context.world_rank: w for w in new_workers
            }
            updated_workers = [
                new_workers_by_rank.get(w.distributed_context.world_rank, w)
                for w in workers
            ]
            self._worker_group_state = WorkerGroupState(
                start_time=self._worker_group_state.start_time,
                placement_group=pg,
                workers=updated_workers,
                sync_actor=sync_actor,
            )

medium

The logic to update self._worker_group_state with the new workers is duplicated. It appears once inside the loop that finds the BackendSetupCallback and again in a separate if block for the case where the callback is not found. This can be refactored to update the state once before the loop, improving code clarity and maintainability.

        # Update worker group state with the new workers.
        new_workers_by_rank = {
            w.distributed_context.world_rank: w for w in new_workers
        }
        updated_workers = [
            new_workers_by_rank.get(w.distributed_context.world_rank, w)
            for w in workers
        ]
        self._worker_group_state = WorkerGroupState(
            start_time=self._worker_group_state.start_time,
            placement_group=pg,
            workers=updated_workers,
            sync_actor=sync_actor,
        )

        # Re-initialize backend (per-group TCPStore + init_process_group)
        # via BackendSetupCallback
        from ray.train.v2._internal.callbacks.backend_setup import BackendSetupCallback

        for callback in self._callbacks:
            if isinstance(callback, BackendSetupCallback):
                callback.reinitialize_workers(self, target_group.world_ranks)
                break

        # Get train context args from callbacks
        train_context_args = {}
        for cb in self._callbacks:
            args = cb.before_init_train_context(new_workers)
            for arg, arg_values in args.items():
                assert len(arg_values) == len(new_workers), (
                    f"Callback {cb} returned {arg} with "
                    f"{len(arg_values)} values, expected {len(new_workers)}."
                )
                assert (
                    arg not in train_context_args
                ), f"Callback {cb} returned {arg} which is already set."
                train_context_args[arg] = arg_values

        # Initialize train context on new workers
        try:
            self._init_train_context_on_workers(
                new_workers, sync_actor, train_context_args
            )
        except RayActorError as actor_error:
            for worker in new_workers:
                ray.kill(worker.actor)
            error_msg = (
                "Replacement workers failed during train context initialization."
            )
            raise WorkerGroupStartupFailedError(error_msg) from actor_error

        # Launch training function on new workers
        ray_get_safe(
            [
                worker.actor.run_train_fn.remote(
                    self._worker_group_context.train_fn_ref
                )
                for worker in new_workers
            ]
        )


Development

Successfully merging this pull request may close these issues.

Ray fails to serialize self-reference objects
